package com.atguigu.dga.assess.assessor.cacl;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.atguigu.dga.assess.assessor.AssessorTemplate;
import com.atguigu.dga.assess.bean.AssessParam;
import com.atguigu.dga.assess.bean.GovernanceAssessDetail;
import com.atguigu.dga.assess.bean.MyFileld;
import com.atguigu.dga.assess.bean.TDsTaskInstance;
import com.atguigu.dga.assess.service.TDsTaskInstanceService;
import com.atguigu.dga.config.MetaConstant;
import com.atguigu.dga.meta.bean.TableMetaInfo;
import com.atguigu.dga.util.CacheUtil;
import com.atguigu.dga.util.SqlParser;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.google.common.collect.Sets;
import lombok.Data;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.parse.ASTNode;
import org.apache.hadoop.hive.ql.parse.HiveParser;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.math.BigDecimal;
import java.text.ParseException;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Created by Smexy on 2023/10/31

 sql语句没有任何join|groupby|union all，且where过滤中没有非分区字段。符合以上情况给0分，其余给10分。

 1.检查导数的sql语句，如果有join|groupby|union all，就是复杂加工，给10分
    没有
 2.检测where过滤中有没有非分区字段
        没有，就是简单加工，给0分。
         有，复杂加工，给10分
    a) 获取where过滤的所有的字段名
            举例: where  dt,name,id
    b)  获取当前sql查询的表的所有分区信息
            举例：  从 a表查询，和从b表查询
                a，按照dt分区，又按照id分区
                b，按照name分区

--------------------------------
    分析sql。
        1.过滤dim_date和ods层是load，没有insertsql，不存在join。
        2.可以通过t_ds_task_instance的task_params 获取当天某张表运行的sql语句。
        3.解析sql
                检测join/group by /union all这些操作可以通过正则表达式检测
                想获取where过滤的字段和查询的表名，正则无法完成，正则只能检测，不能提取。
                        以上操作都可以通过遍历sql语法树来获取。



 */
@Component("SIMPLE_PROCESS")
public class CheckSimpleProcess extends AssessorTemplate
{

    @Autowired
    private CacheUtil cacheUtil;

    @Autowired
    private TDsTaskInstanceService taskInstanceService;
    @Override
    protected void assess(AssessParam param, GovernanceAssessDetail detail) throws Exception {
        //获取层级和表名
        String dwLevel = param.getMetaInfo().getTableMetaInfoExtra().getDwLevel();
        String name = CacheUtil.getKey(param.getMetaInfo());
        System.out.println(name);
        String schemaName = param.getMetaInfo().getSchemaName();
        //把要过滤的表做成一个集合
        HashSet<String> filter = Sets.newHashSet("gmall.dim_date", "gmall.dim_date_tmp");
        //过滤
        if (filter.contains(name) || MetaConstant.DW_LEVEL_ODS.equals(dwLevel)){
            return ;
        }
        //获取到当前表执行的sql
        TDsTaskInstance taskInstance = taskInstanceService.getOne(
            new QueryWrapper<TDsTaskInstance>()
                .eq("name", schemaName + "." + param.getMetaInfo().getTableName())  //找到Task的名字
                .eq("date(start_time)", param.getAssessDate()) //只取今天的
                .eq("state", MetaConstant.TASK_STATE_SUCCESS)
        );

       // System.out.println(taskInstance.getName());
        String rawScript = JSON.parseObject(taskInstance.getTaskParams()).getString("rawScript");
        String sql = parseRawScript(rawScript);
       //  sql = "select * from dim_user_zip t1  where t1.name = 'jack' ";
        //解析sql为语法树
        MyDispatcher myDispatcher = new MyDispatcher();
        SqlParser.parse(sql,myDispatcher);
        //解析完成后，获取结果
        Set<String> complexOperators = myDispatcher.getComplexOperators();
        Set<String> whereFields = myDispatcher.getWhereFields();
        Set<String> selectTableNames = myDispatcher.getSelectTableNames();

        JSONObject jsonObject = new JSONObject();
        jsonObject.put("复杂操作",JSON.toJSONString(complexOperators));
        jsonObject.put("查询的表",JSON.toJSONString(selectTableNames));
        jsonObject.put("where过滤的字段",JSON.toJSONString(whereFields));

        //所有表的元数据信息缓存
        Map<String, TableMetaInfo> cacheMap = cacheUtil.getTableMetaInfoMap();
        //判断是否是简单加工
        if (complexOperators.isEmpty()){
            //没有join,group by ,union 继续判断 where后是否有非分区字段。 获取到所查询的表的分区字段
            Set<String> selectTablePartitonCol = new HashSet<>();

            //把所有所查询表的所有分区字段名字加入到 set集合中
            selectTableNames.stream()
                .forEach(tableName -> {
                    // 需要判断null，当前tableName可能是起的别名，那么hive中没有对应的这种表
                    TableMetaInfo tableMetaInfo = cacheMap.get(schemaName + "." + tableName);
                    if (tableMetaInfo != null){
                        String json = tableMetaInfo.getPartitionColNameJson();
                        List<MyFileld> myFilelds = JSON.parseArray(json, MyFileld.class);
                        //把当前所查询表的所有分区字段名字加入到 set集合中
                        selectTablePartitonCol.addAll(myFilelds.stream().map(MyFileld::getName).collect(Collectors.toList()));
                    }
                } );

            /*
                差集比较
                   selectTablePartitonCol:  [ dt,a,b ]
                    whereFields :  [dt ,a,b, c ]
                        whereFields.差集(selectTablePartitonCol)
                        结果赋值给whereFields，再判断结果
             */
            whereFields.removeAll(selectTablePartitonCol);

            if (whereFields.size() == 0){
                //where之后的字段全部是分区字段 是简单加工
                assessScore(BigDecimal.ZERO,"没有复杂查询，且where后全部是分区字段!",null,detail,false,null);
            }

        }

        //纯粹为了调试，把所有表解析的结果都放入备注中
        detail.setAssessComment(JSON.toJSONString(jsonObject));
    }

    /*
        CTE查询:
            set x = b;
            with xx as ()
            insert xxx
            select xxx;
            set x = a;

        非CTE查询:
                set x = b;
                insert xxxx


         截取的起始位置:    没有with找insert
     */
    private String parseRawScript(String rawScript) {

        int start = rawScript.indexOf("with");

        //确定截取的起始位置
        if (start == -1){
            //没有找到with，就不是CTE查询，找insert
            start = rawScript.indexOf("insert");
        }

        //确定截取的结束位置。找距离开头最近的;
        int end = rawScript.indexOf(";", start);
        if (end == -1){
            end = rawScript.length();
        }
        return rawScript.substring(start,end);
    }

    @Data
    private class MyDispatcher implements Dispatcher{

        //用于收集当前表导数的sql的复杂运算符
        private Set<String> complexOperators = new HashSet<>();
        //收集where过滤的字段
        private Set<String> whereFields = new HashSet<>();
        //收集查询的表名
        private Set<String> selectTableNames = new HashSet<>();

        //使用这个集合判断当前节点是不是复杂查询
        Set<Integer> complexOperatorSet= Sets.newHashSet(
            HiveParser.TOK_JOIN,  //join 包含通过where 连接的情况
            HiveParser.TOK_GROUPBY,       //  group by
            HiveParser.TOK_LEFTOUTERJOIN,       //  left join
            HiveParser.TOK_RIGHTOUTERJOIN,     //   right join
            HiveParser.TOK_FULLOUTERJOIN,     // full join
            HiveParser.TOK_FUNCTION,     //count(1)
            HiveParser.TOK_FUNCTIONDI,  //count(distinct xx)
            HiveParser.TOK_FUNCTIONSTAR, // count(*)
            HiveParser.TOK_SELECTDI,  // distinct
            HiveParser.TOK_UNIONALL   // union ,union all
        );

        //比较运算符
        Set<String> compareOperators= Sets.newHashSet("=",">","<",">=","<=" ,"<>" ,"!=" ,"like","not like"); // in / not in 属于函数计算

        /*
         1.通过语法树，如何判断当前sql有没有 join|group by|union all
        只需要从树的根节点向下遍历，如果出现了子(孙，重孙..)节点的名字有 TOK_GROUPBY，说明这个sql一定是有group by
        只需要从树的根节点向下遍历，如果出现了子(孙，重孙..)节点的名字有 TOK_UNIONALL，说明这个sql一定是有union all
        只需要从树的根节点向下遍历，如果出现了子(孙，重孙..)节点的名字有 TOK_LEFTOUTJOIN，说明这个sql一定是有left join

 2.通过语法树，如何能知道这个sql查询了哪些表?
        遍历树，找TOK_TABNAME。
            如果是库名.表名(TOK_TABNAME 有两个子节点)，找第二个子节点。
            如果是表名(TOK_TABNAME 有一个子节点)，找第一个子节点。

 3.通过语法树，如何能知道这个sql的where部分过滤了哪些字段?
        遍历树，先找到 TOK_WHERE 节点，再向下遍历。
        a)找到TOK_WHERE，获取它的子节点
        b) 看子节点是不是比较运算符(=，<,>)
                是。遍历当前节点的子节点
                    判断是不是.
                        是.，获取.这个节点的第二个孩子
                        不是.，判断是不是TOK_TABLE_OR_COL,是的话取它的第一个孩子

                不是，基本上都是逻辑运算符(and or ),需要向下遍历，继续判断
     */
          /*
            会被GraphWalker在遍历每一个节点时都执行.
                Node nd就代表当前节点，从这个对象中获取非常丰富的信息。
                    示例：
                        name:24,type:24,text:dim_user_zip
                        name:1061,type:1061,text:TOK_TABNAME
                        name:1062,type:1062,text:TOK_TABREF
                    节点的name,type都是一个整数，提前在 HiveParser中定义好了。
                    节点的text，是我们要获取的信息。

                    name可以直接从Node获取，但是type和text必须强转为ASTNode才能获取
         */
        @Override
        public Object dispatch(Node nd, Stack<Node> stack, Object... nodeOutputs) throws SemanticException {

            ASTNode node = (ASTNode) nd;

            //看当前节点是不是一个复杂运算符
            if (complexOperatorSet.contains(node.getType())){
                complexOperators.add(node.getText());
            }

            //看当前节点是不是一个表名的节点
            if (HiveParser.TOK_TABNAME == node.getType()){
                ArrayList<Node> children = node.getChildren();
                //库名.表名
                if (children.size() == 2){
                    selectTableNames.add(((ASTNode)children.get(1)).getText());
                }else {
                    //没写库名
                    selectTableNames.add(((ASTNode)node.getChild(0)).getText());
                }
            }

            //获取where过滤的字段
            if (HiveParser.TOK_WHERE == node.getType()){
                extractWhereFields(node);
            }
            return null;
        }

        /**
         *
         * @param node  TOK_WHERE节点
         */
        private void extractWhereFields(ASTNode node) {
            //获取下一级的子节点
            ArrayList<Node> children = node.getChildren();
            //退出条件 如果当前节点是最有一级，退出
            if (children == null || children.isEmpty()){
                return ;
            }
            for (Node child : children) {
                ASTNode astChild = (ASTNode) child;
                //判断是不是比较运算符
                if (compareOperators.contains(astChild.getText())){
                    ArrayList<Node> compareOperatorChildren = astChild.getChildren();
                    for (Node compareOperatorChild : compareOperatorChildren) {
                        ASTNode astCompareOperatorChild = (ASTNode) compareOperatorChild;
                        if (HiveParser.DOT == astCompareOperatorChild.getType()){
                            whereFields.add(((ASTNode)astCompareOperatorChild.getChild(1)).getText());
                        }else if (HiveParser.TOK_TABLE_OR_COL == astCompareOperatorChild.getType()){
                            whereFields.add(((ASTNode)astCompareOperatorChild.getChild(0)).getText());
                        }
                    }
                }else {
                    //不是比较运算符，递归知道获取到下一级的比较运算符
                    extractWhereFields(astChild);
                }
            }

        }
    }
}
